#########
## Nature Human Behaviour 
## Leading Countries in Global Science Increasingly Receive More Citations than Other Countries Despite Doing Similar Research.
## https://doi.org/10.1038/s41562-022-01351-5
## Harvard Dataverse (Code and Metadata): https://doi.org/10.7910/DVN/WCOINR 
## Step 1B
## Data: Data_20210905
#########

# At a terminal, run:
# '''Note: '$1' is the discipline ID that you pass in.''' 
# ml python/3.6.1
# export TRANSFORMERS_CACHE=/scratch/groups/REDACTED/MAG_Field_Data/Huggingface/
# export GOOGLE_APPLICATION_CREDENTIALS="/home/users/REDACTED/Google_Translate/translation-309420-a1122a44fa1d.json"
# export PYTHONPATH=$GROUP_HOME/python/lib/python3.6/site-packages:$PYTHONPATH
# ml py-ipython/6.1.0_py36 python/3.6.1 py-scipy/1.1.0_py36 py-scikit-learn/0.19.1_py36 py-pandas/0.23.0_py36 gcc/10.1.0 py-pytorch/1.4.0_py36
# ml py-numpy/1.17.2_py36
# srun python3 -u Step_X1B_Python3_RR_MAG_Create_Journal_Censored_Nation_Label_to_Published_Year_Citation.py "$1" 


#############################
### Input
#############################
# Sociology - 144024400
# Applied Physics - 184779094
# Internal medicine - 126322002
# Test - '37037264'
# Test2 - '77088390'
import sys

# The discipline (MAG field-of-study) ID is the first positional command-line
# argument.  It is kept as a string because it is spliced into SQL text and
# into the output filename with str.replace / string concatenation.
if len(sys.argv) < 2:
    # Fail with a readable usage line instead of a bare IndexError.
    sys.exit("Usage: python3 " + sys.argv[0] + " <discipline_id>")
discipline = sys.argv[1]

#############################
### Time Start
#############################
import time 
# Wall-clock timestamp (seconds since the epoch) taken when the script starts.
# NOTE(review): nothing in the visible part of this file reads it back.
start_time = time.time()

#############################
### Modules
#############################
import os 

import bz2, re, string, random, numpy, scipy, math
import datetime
from optparse import OptionParser
from scipy.spatial.distance import pdist, squareform
from collections import Counter
from itertools import chain 
import networkx as nx
import itertools
import gzip
import time
import numpy as np
import itertools
import glob 
import pandas as pd
from math import radians, cos, sin, asin, sqrt
from math import log
from collections import Counter
import bz2 
import pickle
import gc 
from scipy.spatial.distance import cdist

from pyathena import connect
import os.path
from os import path

# Multiprocessing
from multiprocessing.dummy import Pool as ThreadPool
import multiprocessing

from scipy.spatial.distance import cosine
from sklearn.metrics.pairwise import cosine_similarity

import boto3
import smart_open

#############################
### Output Filename 
#############################

# Name of the gzipped CSV produced by this script; the discipline ID is part of
# the name so that every field gets its own output file.
Citation_Year_Filename = "".join([
    "OUTPUT_Python_MAG_RR_Journal_Censored_Citation_Country_by_Year_",
    str(discipline),
    ".csv.gz",
])

#############################
### Functions
#############################
def fieldIDString(x,input_discipline):
    """
    Fill in the ``FIELDIDHERE`` placeholder of a query template.

    x                : template text; every occurrence of ``FIELDIDHERE`` is replaced.
    input_discipline : discipline ID to insert.  It is passed through ``str()``
                       so a numeric ID works as well as a string one.

    Returns the substituted text (unchanged when the placeholder is absent).
    """
    return x.replace("FIELDIDHERE", str(input_discipline))

def reduce_mem_usage(df):
    """ 
    Downcast the numeric columns of a dataframe to reduce memory usage.

    Signed-integer columns are converted to the first of int8 / int16 /
    int32 / int64 whose limits lie strictly outside the column's min and max.
    Float columns are converted the same way to float16 / float32, falling
    back to float64.  Every other column (object, bool, unsigned integer,
    datetime, categorical, extension dtypes, ...) is left untouched; treating
    those as floats would either corrupt them or raise.

    The dataframe is modified in place and also returned.
    """
    signed_int_types = (np.int8, np.int16, np.int32, np.int64)
    narrow_float_types = (np.float16, np.float32)

    for col in df.columns:
        col_type = df[col].dtype

        # Only plain NumPy signed-int ('i') and float ('f') dtypes are safe to
        # downcast with a simple range check.
        if not isinstance(col_type, np.dtype) or col_type.kind not in ("i", "f"):
            continue

        c_min = df[col].min()
        c_max = df[col].max()

        if col_type.kind == "i":
            for candidate in signed_int_types:
                limits = np.iinfo(candidate)
                if c_min > limits.min and c_max < limits.max:
                    df[col] = df[col].astype(candidate)
                    break
            # No candidate matched (e.g. empty column): keep the dtype as is.
        else:
            # NOTE: the check below looks at range only.  float16 carries
            # roughly three significant decimal digits, so values are rounded.
            for candidate in narrow_float_types:
                limits = np.finfo(candidate)
                if c_min > limits.min and c_max < limits.max:
                    df[col] = df[col].astype(candidate)
                    break
            else:
                # Out of float32 range, or min/max is NaN (comparisons False).
                df[col] = df[col].astype(np.float64)
    return df

def fastAthenaQuery(input_sql, input_chunksize = None):
    """
    Run `input_sql` on AWS Athena and load its result file straight from S3.

    The query is executed through the module-level `cursor`; the CSV that
    Athena wrote (``cursor.output_location``) is then read with pandas via
    smart_open, which is MUCH faster than PyAthena fetching a few rows at a
    time through the API.  Afterwards the result object is deleted from S3.

    input_sql       : SQL text to execute.
    input_chunksize : None to load everything into one DataFrame, otherwise
                      the ``chunksize`` handed to ``pd.read_csv``.

    :Return: a Pandas DataFrame of results (or, when `input_chunksize` is
             given, the chunked reader returned by ``pd.read_csv``).
    """
    cursor.execute(input_sql)
    #The s3_path should be of the format: '<bucket_name>/<file_path_inside_the_bucket>'
    #This is the full path with credentials:
    complete_s3_path = 'REDACTED' + cursor.output_location.split("REDACTED")[1]

    if input_chunksize is None:
        # Everything is read eagerly, so the stream can be closed right away.
        with smart_open.smart_open(complete_s3_path) as result_stream:
            outputfile = pd.read_csv(result_stream)
    else:
        # The chunked reader pulls from the stream lazily; it must stay open.
        # NOTE(review): the S3 object is deleted below before the caller has
        # consumed any chunk - confirm that later reads still succeed.
        outputfile = pd.read_csv(smart_open.smart_open(complete_s3_path),chunksize=input_chunksize)

    # Delete the file from the s3 bucket
    # NOTE(review): the full path (not a bucket-relative key) is passed as the
    # object key - verify this matches what boto3 expects here.
    s3 = boto3.resource("REDACTED",
        aws_access_key_id='REDACTED',
        aws_secret_access_key= 'REDACTED')
    obj = s3.Object("REDACTED",
        complete_s3_path)
    obj.delete()

    return outputfile

#############################
### Read in Files
#############################

#############################
## GRID Database with GRID from Field
#############################

# GRID address table; only the institute ID and its country are used below.
df_grid_database = pd.read_csv("addresses.csv")

grid_df = df_grid_database[["grid_id","country"]].rename(columns={'country':"Country"})

# Normalise country names to UPPER_SNAKE_CASE (trim, upper-case, spaces -> "_").
grid_df["Country"] = grid_df["Country"].str.strip().str.upper().str.replace(" ","_", regex=False)
grid_df = grid_df.dropna()

# Whole-value renames: Series.replace with a dict only matches complete cell
# values, never substrings.
grid_df["Country"] = grid_df["Country"].replace(to_replace={'BAHAMAS_THE':'BAHAMAS',"KOREA_DEM_REP":"NORTH_KOREA","DEMOCRATIC_REPUBLIC_OF_THE_CONGO":"CONGO","MACAO_SAR_PEOPLES_R_CHINA":"MACAO","CZECHIA":"CZECH_REPUBLIC"})

# "SAINT_" -> "ST_" is a substring rule (e.g. SAINT_LUCIA -> ST_LUCIA), so it
# has to go through .str.replace; as a dict entry above it would only match a
# cell that is exactly "SAINT_".
grid_df["Country"] = grid_df["Country"].str.replace("SAINT_", "ST_", regex=False)

# One row per (grid_id, Country): repeated rows would multiply the rows of any
# left merge on grid_id and inflate counts summed afterwards.
grid_df = grid_df.drop_duplicates()

#############################
## GRID Database with GRID from Field
#############################

###############################
### Connections
###############################

# Settings shared by both PyAthena connections.
# NOTE(review): credentials are written inline here; consider loading them
# from the environment or an AWS profile instead.
_athena_settings = dict(
    aws_access_key_id='REDACTED',
    aws_secret_access_key='REDACTED',
    s3_staging_dir='REDACTED',
    region_name='REDACTED',
)

# Cursor on its own connection: used for the cursor.execute(...) calls.
cursor = connect(**_athena_settings).cursor()

# Second, separate connection object: handed to pd.read_sql(...).
conn = connect(**_athena_settings)

#############################
### Journal Censoring
#############################

# Per-period journal coverage tables, restricted to the requested discipline.
# NOTE(review): `discipline` comes from sys.argv and is therefore a string; if
# pandas parses `fieldofstudyid` as an integer column this comparison matches
# no rows - confirm the column dtype in the coverage files.
journal_df_coverage_1980_1990 = pd.read_csv("OUTPUT_Python_MAG_Journal_Coverage_From_1980_to_1990.csv.gz").query("fieldofstudyid==@discipline")[["journalid","coverage_over_time"]]
journal_df_coverage_1990_2000 = pd.read_csv("OUTPUT_Python_MAG_Journal_Coverage_From_1990_to_2000.csv.gz").query("fieldofstudyid==@discipline")[["journalid","coverage_over_time"]]
journal_df_coverage_2000_2010 = pd.read_csv("OUTPUT_Python_MAG_Journal_Coverage_From_2000_to_2010.csv.gz").query("fieldofstudyid==@discipline")[["journalid","coverage_over_time"]]
journal_df_coverage_2010_2017 = pd.read_csv("OUTPUT_Python_MAG_Journal_Coverage_From_2010_to_2017.csv.gz").query("fieldofstudyid==@discipline")[["journalid","coverage_over_time"]]

# pd.concat is used instead of DataFrame.append, which no longer exists in
# pandas >= 2.0; row order and index values are the same as with append.

# [2000-2017]
journal_df_coverage_since_2000 = pd.concat([journal_df_coverage_2000_2010, journal_df_coverage_2010_2017])

# [1980-2017]  (rows ordered 1990s, 2000s, 2010s, then 1980s)
journal_df_coverage_since_1980 = pd.concat([journal_df_coverage_1990_2000, journal_df_coverage_since_2000, journal_df_coverage_1980_1990])

# Total coverage per journal over each span (since 1980 and since 2000).
journal_df_coverage_since_1980 = journal_df_coverage_since_1980.groupby(["journalid"]).agg({"coverage_over_time":"sum"}).reset_index()
journal_df_coverage_since_1980 = journal_df_coverage_since_1980.sort_values(["coverage_over_time"],ascending=False)
journal_df_coverage_since_2000 = journal_df_coverage_since_2000.groupby(["journalid"]).agg({"coverage_over_time":"sum"}).reset_index()

# Keep only the journal IDs whose summed coverage equals the maximum observed
# in the span, i.e. the journals tied for the highest total.
journal_df_coverage_since_1980 = journal_df_coverage_since_1980[journal_df_coverage_since_1980.coverage_over_time == journal_df_coverage_since_1980.coverage_over_time.max()]["journalid"].reset_index(drop=True)
journal_df_coverage_since_2000 = journal_df_coverage_since_2000[journal_df_coverage_since_2000.coverage_over_time == journal_df_coverage_since_2000.coverage_over_time.max()]["journalid"].reset_index(drop=True)

# SQL "(id1,id2,...)" list built from the since-1980 selection; it replaces the
# CENSOREDJOURNALSHERE placeholder in the citation query.
# NOTE(review): with no surviving journal this is "()", presumably not a valid
# IN list - confirm before relying on it.
censored_journals = "("+','.join(str(x) for x in journal_df_coverage_since_1980)+")"


###############################
### Queries
###############################

# SQL templates.  FIELDIDHERE and CENSOREDJOURNALSHERE are placeholders that
# are filled in with str.replace just before a statement is executed.

# Existence check: yields one row (column returnTrue) when a table with the
# per-discipline staging name is listed in INFORMATION_SCHEMA.TABLES, and no
# rows otherwise.
check_citation_year_label_query = '''
select true as returnTrue where EXISTS (SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = 'temp_rr_journal_censored_citation_edgelist_year_label_FIELDIDHERE');
'''

# Removes the per-discipline staging table if it is present.
drop_citation_year_label_query = '''
drop table if exists mag_staging.temp_rr_journal_censored_citation_edgelist_year_label_FIELDIDHERE;
'''

# CREATE TABLE AS: materialises a tab-delimited text table in mag_staging.
# Each row is a count of citation links grouped by citing year, cited year,
# the cited paper's field of study, and the GRID IDs attached to the citing
# and to the cited paper.  Only links whose cited paper is in field
# FIELDIDHERE and whose citing and cited journals are both in
# CENSOREDJOURNALSHERE are counted; rows with an empty GRID ID on either side
# are filtered out by the outer query.
citation_query_year_label = '''
create table mag_staging.temp_rr_journal_censored_citation_edgelist_year_label_FIELDIDHERE
WITH (
    format = 'TEXTFILE',
  	field_delimiter = '\t'
  )
as 
Select * 
FROM
      (Select  
      d.fieldofstudyid as Cited_FieldID,
      f.gridid as Citing_GridID,
      b.full_year as Citing_Year, 
      c.full_year as Cited_Year,
      e.gridid as Cited_GridID,
      count() as Number_of_Cites
From "mag_data"."paperreferences" a
left join "mag_data"."papers"  b
on a.paperid=b.paperid
left join "mag_data"."papers"  c
on a.paperreferenceid = c.paperid
left join "mag_data"."paperfieldsofstudy" d
on a.paperreferenceid = d.paperid 
left join (SELECT af.paperid, cf.gridid 
           FROM "mag_data"."paperauthoraffiliations" af
           LEFT JOIN "mag_data"."affiliations" bf
           ON af.affiliationid = bf.affiliationid 
           LEFT JOIN "mag_data"."grid" cf
           ON bf.gridid = cf.gridid
           WHERE bf.gridid<>'') f
on f.paperid = a.paperid 
left join (SELECT ae.paperid, ce.gridid 
           FROM "mag_data"."paperauthoraffiliations" ae
           LEFT JOIN "mag_data"."affiliations" be
           ON ae.affiliationid = be.affiliationid 
           LEFT JOIN "mag_data"."grid" ce
           ON be.gridid = ce.gridid
           WHERE be.gridid<>'') e
on e.paperid = a.paperreferenceid 
WHERE fieldofstudyid=FIELDIDHERE AND b.journalid in CENSOREDJOURNALSHERE AND c.journalid in CENSOREDJOURNALSHERE
GROUP BY 
b.full_year, 
d.fieldofstudyid,
c.full_year,
f.gridid,
e.gridid) Final
where Citing_GridID<>'' and Cited_GridID<>''
'''

# Reads back every row of the per-discipline staging table.
citation_by_year_label_edgelist = '''
SELECT *
FROM "mag_staging"."temp_rr_journal_censored_citation_edgelist_year_label_FIELDIDHERE"
'''

##################################
#### Citations Labels with Year
##################################

# Build the per-discipline staging table only when the existence check comes
# back without rows, i.e. the table is not listed yet.
if pd.read_sql(check_citation_year_label_query.replace('FIELDIDHERE',discipline), conn)["returnTrue"].empty:
    cursor.execute(citation_query_year_label.replace('CENSOREDJOURNALSHERE',censored_journals).replace('FIELDIDHERE',discipline))


# Load the whole staging table into a dataframe (no chunking).
df_citation_year_label_edgelist = fastAthenaQuery(citation_by_year_label_edgelist.replace('FIELDIDHERE',discipline),None)

# The staging table has been read; drop it again.
cursor.execute(drop_citation_year_label_query.replace('FIELDIDHERE',discipline))

##### Note: Forgot to delete earlier citation lists
# NOTE(review): the statement below is a SELECT, so it does not delete
# anything even though the note above and the message below talk about
# deleting - confirm whether a DROP TABLE was intended before changing it.
non_censored_citation_by_year_label_edgelist = '''
SELECT *
FROM "mag_staging"."temp_rr_citation_edgelist_year_label_FIELDIDHERE"
'''
try:
    cursor.execute(non_censored_citation_by_year_label_edgelist.replace('FIELDIDHERE',discipline))
except Exception:
    # Best effort: a failing query (e.g. the table is absent) is reported and
    # ignored, while KeyboardInterrupt / SystemExit still propagate.
    print("No non-censored citation table left to delete.")

##################################
#### Merge with GRID and Aggregate
##################################

def _with_country(edges, gridid_column, country_label):
    """Left-join `grid_df` on `gridid_column` and expose the matched country as `country_label`."""
    joined = pd.merge(edges, grid_df,
        left_on=[gridid_column],
        right_on=["grid_id"],
        how="left")
    joined = joined.rename(columns={"Country": country_label})
    # The join key copied in from grid_df is redundant after the merge.
    return joined.drop(columns=["grid_id"])

# Country of the citing side first, then of the cited side.
df_citation_year_label_edgelist = _with_country(df_citation_year_label_edgelist, "citing_gridid", "Country_Citing")
df_citation_year_label_edgelist = _with_country(df_citation_year_label_edgelist, "cited_gridid", "Country_Cited")

# Collapse institute-level rows into one row per country pair, year pair and
# field, summing the citation counts, then switch to the output column names.
_group_keys = ["Country_Citing","Country_Cited","citing_year","cited_year","cited_fieldid"]
_output_names = {"citing_year":"Citing_Year","cited_year":"Cited_Year","number_of_cites":"Number_of_Cites","cited_fieldid":"Discipline"}

_country_totals = df_citation_year_label_edgelist.groupby(_group_keys).agg({"number_of_cites":"sum"})
df_citation_year_label_edgelist = _country_totals.reset_index().rename(columns=_output_names)

##################################
#### Print Out CSV
##################################

# Write the aggregated table as a gzip-compressed CSV, without the index column.
df_citation_year_label_edgelist.to_csv(
    Citation_Year_Filename,
    index=False,
    compression='gzip',
)
